# encoding: utf-8


import os
import random
import shutil


# Source images: one sub-folder per class under IMG_PATH.
IMG_PATH = './imgs'
NORMAL_IMG_PATH = IMG_PATH + '/normal_img'
MALWARE_IMG_PATH = IMG_PATH + '/malware_img'

# Output dataset tree: DATA_PATH/{train,valid}/{normal,malware}.
DATA_PATH = './data'
TRAIN_PATH = DATA_PATH + '/train'
VALID_PATH = DATA_PATH + '/valid'
NORMAL_TRAIN_PATH = TRAIN_PATH + '/normal'
MALWARE_TRAIN_PATH = TRAIN_PATH + '/malware'
NORMAL_VALID_PATH = VALID_PATH + '/normal'
MALWARE_VALID_PATH = VALID_PATH + '/malware'


def safe_mkdir(path):
    """Create ``path`` as a new, empty directory, replacing whatever is there.

    An existing directory is deleted together with all of its contents.
    A non-directory entry at ``path`` (regular file, or a symlink, including
    a dangling one) is removed as well, so the directory can always be
    created. Missing parent directories are created as needed.

    Args:
        path: Directory to (re)create.

    Raises:
        OSError: If the existing entry cannot be removed or the directory
            cannot be created (e.g. insufficient permissions).
    """
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    elif os.path.lexists(path):
        # rmtree refuses files and symlinks; lexists also catches dangling
        # links that os.path.exists would report as absent.
        os.remove(path)
    os.makedirs(path)


# Fraction of each class copied to the training set; the rest is validation.
_TRAIN_RATIO = 0.8

# Filename prefixes (the text before the first '.') accepted for each malware
# family. Bucket order: backdoors, trojans, worms, exploits.
_FAMILY_PREFIXES = (
    ('Backdoor',),
    ('Trojan-Downloader', 'Trojan-GameThief', 'Trojan-PSW', 'Trojan-Spy',
     'Trojan-Dropper', 'Trojan-Banker'),
    ('Email-Worm', 'Net-Worm', 'P2P-Worm', 'IRC-Worm', 'IM-Worm'),
    ('Exploit',),
)

# Reverse lookup: prefix -> index of its bucket in _FAMILY_PREFIXES.
_PREFIX_TO_BUCKET = {
    prefix: index
    for index, prefixes in enumerate(_FAMILY_PREFIXES)
    for prefix in prefixes
}


def group_by_family(filenames):
    """Bucket malware file names by family, based on their filename prefix.

    The prefix is the part of the name before the first '.'. Names whose
    prefix is not listed in ``_FAMILY_PREFIXES`` are dropped.

    Args:
        filenames: Iterable of file names.

    Returns:
        A list of four lists, in the order backdoors, trojans, worms,
        exploits. Input order is preserved inside each list.
    """
    buckets = [[] for _ in _FAMILY_PREFIXES]
    for name in filenames:
        index = _PREFIX_TO_BUCKET.get(name.split('.', 1)[0])
        if index is not None:
            buckets[index].append(name)
    return buckets


def sample_proportionally(groups, keep_num):
    """Take from each group a share of ``keep_num`` proportional to its size.

    Each group contributes its first ``keep_num * len(group) // total`` items,
    where ``total`` is the combined size of all groups. Integer arithmetic is
    used on purpose: the float form ``int(keep_num * (len(group) / total))``
    can round an exact quota down by one (e.g. 100 * (29 / 100) -> 28).
    A quota larger than the group simply yields the whole group.

    Args:
        groups: Sequence of lists to sample from.
        keep_num: Target number of items over all groups.

    Returns:
        A new list with the kept items, group after group. Empty when all
        groups are empty.
    """
    total = sum(len(group) for group in groups)
    if total == 0:
        # Nothing to sample from; also avoids dividing by zero below.
        return []
    kept = []
    for group in groups:
        kept.extend(group[:keep_num * len(group) // total])
    return kept


def split_and_copy(names, src_dir, train_dir, valid_dir, train_ratio=_TRAIN_RATIO):
    """Copy files from ``src_dir`` into a training and a validation folder.

    The first ``int(train_ratio * len(names))`` names go to ``train_dir``,
    all remaining names go to ``valid_dir``.

    Args:
        names: File names (relative to ``src_dir``) in the order to split.
        src_dir: Folder that holds the files.
        train_dir: Existing destination folder for the training part.
        valid_dir: Existing destination folder for the validation part.
        train_ratio: Fraction of ``names`` assigned to training.

    Returns:
        Tuple ``(train_count, valid_count)``.
    """
    cut = int(train_ratio * len(names))
    for name in names[:cut]:
        shutil.copy(os.path.join(src_dir, name), os.path.join(train_dir, name))
    for name in names[cut:]:
        shutil.copy(os.path.join(src_dir, name), os.path.join(valid_dir, name))
    return cut, len(names) - cut


def main():
    """Build the train/valid folders from the normal and malware images."""
    # Create the output folders (emptying any that already exist).
    for out_dir in (NORMAL_TRAIN_PATH, NORMAL_VALID_PATH,
                    MALWARE_TRAIN_PATH, MALWARE_VALID_PATH):
        safe_mkdir(out_dir)

    normal_imgs = os.listdir(NORMAL_IMG_PATH)
    families = group_by_family(os.listdir(MALWARE_IMG_PATH))
    backdoors, trojans, worms, exploits = families

    print('Backdoor num:', len(backdoors))
    print('Trojan num:', len(trojans))
    print('worm num:', len(worms))
    print('exploit num:', len(exploits))

    # Shuffle inside each family so the kept prefix is a random sample.
    for family in families:
        random.shuffle(family)

    # Keep (at most) as many malware images as there are normal images,
    # preserving the relative size of each family.
    malware_imgs = sample_proportionally(families, len(normal_imgs))
    print('正常软件的数量:', len(normal_imgs))
    print('恶意软件的数量:', len(malware_imgs))

    # Shuffle both classes; the before/after prints make the effect visible.
    print('normal_imgs[:10]: ', normal_imgs[:10])
    print('malware_imgs[:10]: ', malware_imgs[:10])
    random.shuffle(normal_imgs)
    random.shuffle(malware_imgs)
    print('normal_imgs[:10]: ', normal_imgs[:10])
    print('malware_imgs[:10]: ', malware_imgs[:10])

    # Normal software: 80% to the training set, 20% to the validation set.
    split_and_copy(normal_imgs, NORMAL_IMG_PATH,
                   NORMAL_TRAIN_PATH, NORMAL_VALID_PATH)

    # Malware: same ratio, computed on the number of malware images actually
    # kept. That number can be lower than the normal count, and splitting at
    # the normal count's 80% mark would then starve the validation set.
    split_and_copy(malware_imgs, MALWARE_IMG_PATH,
                   MALWARE_TRAIN_PATH, MALWARE_VALID_PATH)


if __name__ == '__main__':
    main()

